#coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import sys
import os

from ml_feature.util import get_logger, get_app_name, AUTHOR, TENSORBOARD_DIR, bytes_to_human
from PyQt5.QtCore import QUrl, pyqtSlot, pyqtSignal, QThread
from PyQt5 import QtCore
from PyQt5.QtWebChannel import QWebChannel
from PyQt5.QtWidgets import QApplication, QMainWindow, QAction, qApp, QMessageBox, QFileDialog, QMenu, QDesktopWidget
from PyQt5.QtWebEngineWidgets import QWebEngineView, QWebEnginePage
from ml_feature.core.libs import load_pb_graph, make_tensor_html_file, get_placehoder_prevalue
from ml_feature.util.QtDialog import Transfer_format, ProgressBarDialog, Save_as
import json


class LoadFileThread(QThread):
    """
    Worker thread that loads the model file named by ``self_share.file_name``.

    The pair returned by ``load_pb_graph`` is stored on the shared object as
    ``graph_def`` and ``placeholder_valuedict``. The ``setvalue`` signal then
    hands control back to the receiver (the main thread) so it can apply the
    Placeholder preset values. ``setProgressValue(text, step)`` reports
    progress along the way.
    """
    setvalue = pyqtSignal()
    setProgressValue = pyqtSignal(str, int)

    def __init__(self, self_share):
        """
        :param self_share: object that provides ``file_name`` and receives
            ``graph_def`` / ``placeholder_valuedict``
        """
        super(LoadFileThread, self).__init__()
        self.self_share = self_share

    def run(self):
        """Load the model, report progress (0 before, 4 after), then emit ``setvalue``."""
        report = self.setProgressValue.emit
        target = self.self_share
        report('load_graph_model', 0)
        loaded = load_pb_graph(target.file_name, emit=report, emit_value=0)
        target.graph_def, target.placeholder_valuedict = loaded
        report('setPlaceholder prevalue', 4)
        self.setvalue.emit()


class GenerateHtmlThread(QThread):
    """
    Worker thread that extracts features from the prepared graph / graph_def
    and renders them as an HTML page via ``make_tensor_html_file``.

    The second value returned by ``make_tensor_html_file`` is stored on the
    shared object as ``sequence``. The first value is emitted with ``done``;
    presumably it is the path of the generated HTML file (TODO confirm
    against ``make_tensor_html_file``). ``setProgressValue(text, step)``
    reports progress.
    """
    done = pyqtSignal(str)
    setProgressValue = pyqtSignal(str, int)

    def __init__(self, self_share):
        """
        :param self_share: object providing ``graph``, ``graph_def``,
            ``placeholder_valuedict`` and ``file_name``; receives ``sequence``
        """
        super(GenerateHtmlThread, self).__init__()
        self.self_share = self_share

    def run(self):
        """Render the HTML page, then emit progress step 9 and ``done``."""
        share = self.self_share
        model_path = share.file_name
        # Page header: model file name plus its human-readable size.
        header = "<b>Model: %s (%s)</b><hr />" % (
            os.path.basename(model_path),
            bytes_to_human(os.path.getsize(model_path)),
        )
        html_result, share.sequence = make_tensor_html_file(
            share.graph,
            share.graph_def,
            TENSORBOARD_DIR,
            placeholder_valuedict=share.placeholder_valuedict,
            pre_html_data=header,
            post_html_data="<hr />Copyright: 2017-2018",
            emit=self.setProgressValue.emit,
            emit_value=4,
        )
        self.setProgressValue.emit('generate html complete', 9)
        self.done.emit(html_result)


class MLCMainWindow(QMainWindow):
    """
    Main application window.

    Provides the File menu (open / export / save as / exit / about) and a
    central ``HtmlView`` that displays the generated model page. It also
    exposes the ``test_js_call_qt`` slot, which the page's JavaScript calls
    through QWebChannel to reload the graph with new Placeholder shapes.

    Loading a file uses two worker threads. ``LoadFileThread`` parses the
    model file. ``setvalue`` then applies the Placeholder values and starts
    ``GenerateHtmlThread``, whose result is displayed by ``done``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the main window: size and position, actions, menus and the
        central web view, then show it.

        :param args: forwarded to ``QMainWindow.__init__``
        :param kwargs: forwarded to ``QMainWindow.__init__``
        """
        self.logger = get_logger(type(self).__name__)
        QMainWindow.__init__(self, *args, **kwargs)

        # Model state, filled in by the loader threads and the reload slot.
        # graph_def stays None until a model has been loaded successfully.
        self.file_name = None
        self.graph = None
        self.graph_def = None
        self.sequence = None
        self.placeholder_valuedict = dict()

        self.setWindowTitle(get_app_name())
        self.statuts_message("initialization wait ... ")

        # Size the window to 90% x 70% of the screen and centre it.
        sc_size = QDesktopWidget().screenGeometry(-1)
        size_x = sc_size.width()
        size_y = sc_size.height()

        w_size_x = int(size_x * 0.9)
        w_size_y = int(size_y * 0.7)
        w_x = int((size_x - w_size_x) / 2)
        w_y = int((size_y - w_size_y) / 2)

        self.setGeometry(w_x, w_y, w_size_x, w_size_y)
        # Create actions, menus and the central view.
        self.create_actions()
        self.create_menus()
        self.setEnabled(True)
        self.create_views()

        # Ready to show.
        self.statuts_message("ready")
        self.show()

    def statuts_message(self, msg):
        """
        Show a message in the main window's status bar.

        :param msg: text to display
        :return:
        """
        self.statusBar().showMessage(msg)

    def create_actions(self):
        """
        Create the window actions: exit, open file, about, export (mlc file)
        and save as.

        :return:
        """
        # exit
        self.action_exit = QAction(self.tr("&Exit"), self)
        self.action_exit.setStatusTip(self.tr("Exit application"))
        self.action_exit.triggered.connect(qApp.quit)

        # open
        self.action_open = QAction(self.tr("&Open"), self)
        self.action_open.setStatusTip(self.tr("Open model file"))
        self.action_open.triggered.connect(self.handler_action_open_file)

        # about
        self.action_about = QAction(self.tr("&About"), self)
        self.action_about.setStatusTip(self.tr("show about information"))
        self.action_about.triggered.connect(self.handler_action_about)

        # export
        self.action_transfer_format = QAction(self.tr("&export"), self)
        self.action_transfer_format.setStatusTip(self.tr("export graph as mlc file"))
        self.action_transfer_format.triggered.connect(self.hanlder_action_transfer_format)

        # save as
        self.action_save_as = QAction(self.tr("&save as"), self)
        self.action_save_as.setStatusTip(self.tr("save graph as pb or pbtxt file"))
        self.action_save_as.triggered.connect(self.handler_action_save_as)

    def create_menus(self):
        """
        Create the (non-native) menu bar and add the File menu with the
        previously created actions.

        :return:
        """
        self.menu_bar = self.menuBar()
        self.menu_bar.setNativeMenuBar(False)

        self.menu_file = self.menu_bar.addMenu("&File")
        self.menu_file.addAction(self.action_open)
        self.menu_file.addAction(self.action_transfer_format)
        self.menu_file.addAction(self.action_save_as)
        self.menu_file.addAction(self.action_exit)
        self.menu_file.addAction(self.action_about)

    def handler_action_open_file(self):
        """
        Handler for "Open". Ask for a .pb / .pbtxt file, then load it in a
        worker thread while a progress dialog is shown.

        Cancelling the dialog leaves the currently loaded model and
        ``self.file_name`` untouched.

        :return:
        """
        ops = QFileDialog.Options()
        ops |= QFileDialog.DontUseNativeDialog
        file_name = QFileDialog.getOpenFileName(self, caption="open model protobuf format file",
                                                filter="*.pb *.pbtxt", options=ops)[0]
        if not file_name:
            # The user cancelled. Do not clobber self.file_name, because
            # GenerateHtmlThread reads it again when the page is reloaded.
            return
        if not os.path.isfile(file_name):
            msg = QMessageBox()
            msg.setText("Error: %s not find" % file_name)
            msg.exec_()
            return
        self.file_name = file_name
        self.view_html.load(QUrl(''))
        self.loadfile_thread = LoadFileThread(self)
        self.loadfile_thread.setvalue.connect(self.setvalue)
        self.loadfile_thread.setProgressValue.connect(self.setProgress)
        self.progressDialog = ProgressBarDialog(10, parent=self)
        self.loadfile_thread.start()
        self.progressDialog.show()
        # Do not call self.loadfile_thread.exec_() here. QThread.exec_() is
        # meant to be called from run(). Called from the GUI thread, it runs a
        # nested event loop in the GUI thread that never exits when the
        # worker finishes. The dialog's own loop processes the worker's
        # progress/finish signals until done()/setvalue() close it.
        self.progressDialog.exec_()

    def _has_model(self):
        """Return True once a model's graph_def is available (loaded and not reset)."""
        return self.graph_def is not None

    def _run_save_dialog(self, dialog):
        """
        Run a modal save/export dialog, report its outcome in the status bar,
        then destroy it.

        :param dialog: dialog exposing ``exec_()``, ``getTransFlag()`` and ``destroy()``
        """
        dialog.exec_()
        if dialog.getTransFlag():
            self.statuts_message('save succeed')
        else:
            self.statuts_message('save failed')
        dialog.destroy()

    def hanlder_action_transfer_format(self):
        """
        Handler for "export". Export the loaded graph as a pseudo-instruction
        (mlc) file through the ``Transfer_format`` dialog.

        :return:
        """
        # Both graph_def and the op sequence produced by GenerateHtmlThread
        # are required by the export dialog.
        if not self._has_model() or self.sequence is None:
            self.statuts_message('Please open a model protobuf file')
            return
        dialog = Transfer_format(graph_def=self.graph_def, sequence=self.sequence, parent=self)
        self._run_save_dialog(dialog)

    def handler_action_save_as(self):
        """
        Handler for "save as". Save the loaded graph_def as a pb or pbtxt
        file through the ``Save_as`` dialog.

        :return:
        """
        if not self._has_model():
            self.statuts_message('Please open a model protobuf file')
            return
        dialog = Save_as(graph_def=self.graph_def, parent=self)
        self._run_save_dialog(dialog)

    def handler_action_about(self):
        """
        Handler for "About". Show the application name and author.

        :return:
        """
        # Translate the format strings, not the already-formatted text, so
        # translation catalogs can actually match them.
        QMessageBox.about(self, self.tr("About %s") % get_app_name(),
                          self.tr("%s\n\n Author: %s") % (get_app_name(), AUTHOR))

    def create_views(self):
        """
        Create the central web view and register this window on its
        QWebChannel so the page's JavaScript can call its slots.

        :return:
        """
        self.view_html = HtmlView()
        self.setCentralWidget(self.view_html)

        self.view_html.jsInit(self, self)

    @staticmethod
    def _parse_placeholder_shapes(txt):
        """
        Parse the JSON sent by the page into Placeholder shapes.

        :param txt: JSON object whose 'placeholder_datalist' maps each
            Placeholder name to a list of dimensions
        :return: dict mapping name -> ``tf.TensorShape``, or None when the
            JSON is malformed or any dimension is not a positive integer
        """
        import tensorflow as tf
        try:
            datalist = json.loads(txt)['placeholder_datalist']
            shapes = dict()
            for key, dims in datalist.items():
                arr = [int(i) for i in dims]
                # Reject before building the shape, so invalid dims never
                # reach TensorFlow.
                if any(val <= 0 for val in arr):
                    return None
                # Same static shape tf.ones(arr).get_shape() would give,
                # without adding an op to the default graph on every reload.
                shapes[key] = tf.TensorShape(arr)
        except (ValueError, KeyError, TypeError, AttributeError):
            return None
        return shapes

    @pyqtSlot(str, name='test_js_call_qt')
    def test_js_call_qt(self, txt):
        """
        Slot called from the page's JavaScript (via QWebChannel) when the
        reload button is clicked. It re-applies new Placeholder shapes and
        regenerates the page.

        :param txt: JSON with the Placeholder values entered by the user
            (see ``_parse_placeholder_shapes``). The literal string 'False'
            is treated as invalid input; presumably the page sends it when
            its own validation fails (TODO confirm against the page JS).
        :return:
        """
        value_list = None
        if txt != 'False':
            value_list = self._parse_placeholder_shapes(txt)
        if value_list is None:
            msg = QMessageBox()
            msg.setText("Error: value is invalid")
            msg.exec_()
            return

        # Work on locals so a failed reload leaves the loaded model intact.
        graph, graph_def = get_placehoder_prevalue(self.graph_def, value_list)
        # get_placehoder_prevalue signals failure by returning False as the graph.
        if graph == False:
            msg = QMessageBox()
            msg.setText("Error: value is invalid, can't reload with the new value")
            msg.exec_()
            return

        self.placeholder_valuedict = value_list
        self.graph, self.graph_def = graph, graph_def
        self.view_html.load(QUrl(''))
        self.statuts_message('clear')
        self.progressDialog = ProgressBarDialog(10, parent=self)
        self._start_generate_html()
        self.progressDialog.show()

    def _start_generate_html(self):
        """Start GenerateHtmlThread with its progress/result routed to setProgress/done."""
        self.generateHtml_thread = GenerateHtmlThread(self)
        self.generateHtml_thread.done.connect(self.done)
        self.generateHtml_thread.setProgressValue.connect(self.setProgress)
        self.generateHtml_thread.start()

    def setvalue(self):
        """
        Apply the Placeholder preset values (read by LoadFileThread) and
        compute the output shape of every op in the graph. On success, start
        generating the HTML page. On failure, close the progress dialog and
        reset the model state.

        :return:
        """
        self.graph, self.graph_def = get_placehoder_prevalue(self.graph_def, self.placeholder_valuedict)
        # get_placehoder_prevalue signals failure by returning False as the graph.
        if self.graph == False:
            self.progressDialog.closeDialog()
            self.graph = None
            self.graph_def = None
            self.placeholder_valuedict = dict()
        else:
            # No QThread.exec_() from the GUI thread here; see handler_action_open_file.
            self._start_generate_html()

    def done(self, html):
        """
        Called when the HTML page has been generated. Display it in the main
        window and close the progress dialog.

        :param html: first value from make_tensor_html_file, loaded here as a
            local file path
        :return:
        """
        self.setProgress('start load html', 9)
        self.view_html.load(QUrl.fromLocalFile(html))
        self.statuts_message("Current: %s" % self.file_name)
        self.setProgress('success', 10)
        self.progressDialog.closeDialog()

    def setProgress(self, text, time):
        """
        Update the progress dialog.

        :param text: text shown above the progress bar
        :param time: progress step
        :return:
        """
        self.progressDialog.setProgress(text, time)


class HtmlView(QWebEngineView):
    """
    Web view for the generated model page. It has a reduced context menu and
    a helper that exposes a Qt object to the page's JavaScript.
    """

    def contextMenuEvent(self, event):
        """Pop up a context menu that offers only Reload, and only while it is enabled."""
        popup = QMenu(self)
        reload_action = self.pageAction(QWebEnginePage.Reload)
        if reload_action.isEnabled():
            popup.addAction(reload_action)

        popup.exec_(event.globalPos())

    def jsInit(self, page, mainwd):
        """
        Install a QWebChannel on this view's page that registers ``mainwd``
        under the JavaScript name 'qt_mainwindow'.

        :param page: parent object of the created QWebChannel (keeps it alive)
        :param mainwd: QObject whose slots become callable from JavaScript
        """
        web_channel = QWebChannel(page)
        web_channel.registerObject('qt_mainwindow', mainwd)
        self.page().setWebChannel(web_channel)

def start():
    """
    Create the Qt application, open the main window and run the event loop.
    The process exits with the event loop's return code.

    :return: does not return (calls ``sys.exit``)
    """
    application = QApplication(sys.argv)
    # Hold a reference for the lifetime of the event loop so the window is not garbage-collected.
    window = MLCMainWindow(flags=QtCore.Qt.WindowTitleHint)
    exit_code = application.exec_()
    sys.exit(exit_code)